package com.luqili.utils.pub.sql;

import com.luqili.utils.pub.json.JsonUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.ibatis.jdbc.SqlRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

/**
 * Base DAO offering simple SQL helpers on top of MyBatis {@link SqlRunner}.
 *
 * <p>Every call obtains its own {@link Connection} from {@link #getDataSource()} and closes it
 * before returning. The original notes say transaction annotations are supported; whether a call
 * actually joins a surrounding transaction depends on the {@link DataSource} the subclass returns
 * (NOTE(review): presumably a transaction-aware DataSource — confirm).
 *
 * <p>Query results have their column keys rewritten to the letter case used in the SQL text
 * (see {@code restoreColumnCase}), because the row maps returned by SqlRunner do not carry that
 * case.
 *
 * @author luqili
 */
public abstract class LuBaseDAO {
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Splits SQL into candidate column tokens on runs of whitespace and/or commas. */
    private static final Pattern SQL_TOKEN_SEPARATOR = Pattern.compile("[\\s,]+");

    public LuBaseDAO() {
    }

    /**
     * Returns the data source every operation of this DAO runs against.
     *
     * @return the data source; {@code null} makes every operation fail with {@link LuDBException}
     */
    public abstract DataSource getDataSource();

    /**
     * Executes an update statement with positional ({@code ?}) parameters.
     *
     * @param sql    SQL statement with {@code ?} placeholders
     * @param params values bound to the placeholders in order; {@code null} means no parameters
     * @return the value returned by {@link SqlRunner#update} (presumably the affected-row count)
     * @throws LuDBException if the data source is missing or the statement fails
     */
    public Integer update(String sql, List<Object> params) {
        try (Connection ct = requireDataSource().getConnection()) {
            return new SqlRunner(ct).update(sql, toArgs(params));
        } catch (SQLException e) {
            throw sqlFailure(sql, params, e);
        }
    }

    /**
     * Executes an insert statement with positional ({@code ?}) parameters.
     *
     * @param sql    INSERT statement with {@code ?} placeholders
     * @param params values bound to the placeholders in order; {@code null} means no parameters
     * @return the value returned by {@link SqlRunner#insert}
     *         (NOTE(review): presumably a generated key or a MyBatis sentinel — confirm)
     * @throws LuDBException if the data source is missing or the statement fails
     */
    public Integer insert(String sql, List<Object> params) {
        try (Connection ct = requireDataSource().getConnection()) {
            return new SqlRunner(ct).insert(sql, toArgs(params));
        } catch (SQLException e) {
            throw sqlFailure(sql, params, e);
        }
    }

    /**
     * Runs a single-row query (typically {@code SELECT COUNT(*) ...}) and returns the first
     * integral column value found in the row.
     *
     * <p>Accepted value types are {@code Integer}, {@code Long}, {@code Short}, {@code Byte},
     * {@code BigInteger} and {@code BigDecimal} without a fractional part; the latter covers
     * drivers that report counts as {@code NUMBER}/{@code DECIMAL}. When several columns qualify,
     * the row map's iteration order decides which one is used.
     *
     * @param sql    query with {@code ?} placeholders
     * @param params values bound to the placeholders in order; {@code null} means no parameters
     * @return the integral value, 0 if it lies outside the {@code int} range, or 0 if the row has
     *         no integral column
     * @throws LuDBException if the data source is missing or the query fails
     */
    public Integer selectOneCount(String sql, List<Object> params) {
        try (Connection ct = requireDataSource().getConnection()) {
            Map<String, Object> row = new SqlRunner(ct).selectOne(sql, toArgs(params));
            if (row != null) {
                for (Object value : row.values()) {
                    Integer count = toCount(value);
                    if (count != null) {
                        return count;
                    }
                }
            }
        } catch (SQLException e) {
            throw sqlFailure(sql, params, e);
        }
        return 0;
    }

    /**
     * Runs a query expected to return one row, restoring column-name case from the SQL text.
     *
     * <p>What happens when the query yields zero or several rows is decided by
     * {@link SqlRunner#selectOne}; a {@link SQLException} it raises surfaces as
     * {@link LuDBException}.
     *
     * @param sql    query with {@code ?} placeholders
     * @param params values bound to the placeholders in order; {@code null} means no parameters
     * @return the row keyed by column name (empty map if the row is {@code null} or empty)
     * @throws LuDBException if the data source is missing or the query fails
     */
    public Map<String, Object> selectOne(String sql, List<Object> params) {
        try (Connection ct = requireDataSource().getConnection()) {
            Map<String, Object> row = new SqlRunner(ct).selectOne(sql, toArgs(params));
            return restoreColumnCase(sql, sqlTokens(sql), row);
        } catch (SQLException e) {
            throw sqlFailure(sql, params, e);
        }
    }

    /**
     * Runs a query and returns all rows, restoring column-name case from the SQL text.
     *
     * @param sql    query with {@code ?} placeholders
     * @param params values bound to the placeholders in order; {@code null} means no parameters
     * @return one map per row; an empty list when there are no rows
     * @throws LuDBException if the data source is missing or the query fails
     */
    public List<Map<String, Object>> selectAll(String sql, List<Object> params) {
        try (Connection ct = requireDataSource().getConnection()) {
            List<Map<String, Object>> rows = new SqlRunner(ct).selectAll(sql, toArgs(params));
            return restoreColumnCase(sql, rows);
        } catch (SQLException e) {
            throw sqlFailure(sql, params, e);
        }
    }

    /**
     * Executes a raw SQL statement without parameters, discarding any result.
     *
     * @param sql the statement to run
     * @throws LuDBException if the data source is missing or the statement fails
     */
    public void runSQL(String sql) {
        try (Connection ct = requireDataSource().getConnection()) {
            new SqlRunner(ct).run(sql);
        } catch (SQLException e) {
            logger.error("\nSQL语句执行错误:" + sql);
            throw new LuDBException("查询出错:" + e.getMessage(), e);
        }
    }

    /** Returns the configured data source, failing fast when it has not been set up. */
    private DataSource requireDataSource() {
        DataSource dataSource = getDataSource();
        if (dataSource == null) {
            throw new LuDBException("数据库未初始化");
        }
        return dataSource;
    }

    /** Converts the parameter list to the varargs array SqlRunner expects; null means none. */
    private static Object[] toArgs(List<Object> params) {
        return params == null ? new Object[0] : params.toArray();
    }

    /** Logs the failing statement with its parameters and wraps the cause for the caller. */
    private LuDBException sqlFailure(String sql, List<Object> params, SQLException e) {
        logger.error("\nSQL语句执行错误:" + sql + "\n参数:" + JsonUtils.toJson(params));
        return new LuDBException("查询出错:" + e.getMessage(), e);
    }

    /**
     * Interprets a column value as a count.
     *
     * @return the value as an int (0 if it is integral but outside the int range), or
     *         {@code null} if the value is not an integral number
     */
    private static Integer toCount(Object value) {
        if (value instanceof Integer || value instanceof Long || value instanceof Short
                || value instanceof Byte || value instanceof BigInteger) {
            // NumberUtils.toInt yields 0 for values outside the int range.
            return NumberUtils.toInt(value.toString());
        }
        if (value instanceof BigDecimal) {
            BigDecimal decimal = (BigDecimal) value;
            // Only integral decimals (e.g. COUNT on NUMBER columns) count; 12.5 is skipped.
            if (decimal.signum() == 0 || decimal.stripTrailingZeros().scale() <= 0) {
                try {
                    return decimal.intValueExact();
                } catch (ArithmeticException outOfIntRange) {
                    return 0;
                }
            }
        }
        return null;
    }

    /**
     * Restores the column-name case of every row, tokenizing the SQL only once.
     *
     * @param sql  the SQL the rows came from
     * @param rows result rows; may be {@code null}
     * @return a new list of re-keyed rows (empty when {@code rows} is null or empty)
     */
    private static List<Map<String, Object>> restoreColumnCase(String sql,
                                                               List<Map<String, Object>> rows) {
        List<Map<String, Object>> result = new ArrayList<>();
        if (rows == null || rows.isEmpty()) {
            return result;
        }
        List<String> tokens = sqlTokens(sql);
        for (Map<String, Object> row : rows) {
            result.add(restoreColumnCase(sql, tokens, row));
        }
        return result;
    }

    /**
     * Re-keys one row so that each column name uses the spelling found in the SQL text.
     *
     * @param sql    the SQL the row came from
     * @param tokens the SQL pre-split by {@link #sqlTokens(String)}
     * @param row    the row; may be {@code null}
     * @return a new map (empty when {@code row} is null or empty)
     */
    private static Map<String, Object> restoreColumnCase(String sql, List<String> tokens,
                                                         Map<String, Object> row) {
        Map<String, Object> result = new HashMap<>();
        if (row == null || row.isEmpty()) {
            return result;
        }
        for (Map.Entry<String, Object> entry : row.entrySet()) {
            result.put(resolveColumnName(sql, tokens, entry.getKey()), entry.getValue());
        }
        return result;
    }

    /**
     * Finds how a result column is spelled in the SQL text.
     *
     * <p>The method tries three things in order:
     * <ol>
     *   <li>the first SQL token equal to {@code key} ignoring case;</li>
     *   <li>the first case-insensitive occurrence of {@code key} preceded by a space, or anywhere
     *       in the SQL;</li>
     *   <li>otherwise {@code key} itself, e.g. for {@code SELECT *}. Previously this case produced
     *       an empty or garbage key that could overwrite other columns.</li>
     * </ol>
     */
    private static String resolveColumnName(String sql, List<String> tokens, String key) {
        for (String token : tokens) {
            if (StringUtils.equalsIgnoreCase(token, key)) {
                // First match: the SELECT list precedes WHERE/ORDER BY mentions of the column.
                return token;
            }
        }
        int index = StringUtils.indexOfIgnoreCase(sql, " " + key);
        if (index != -1) {
            index++; // skip the leading space that was part of the search string
        } else {
            index = StringUtils.indexOfIgnoreCase(sql, key);
        }
        if (index == -1) {
            return key;
        }
        return StringUtils.substring(sql, index, index + key.length());
    }

    /** Splits SQL on whitespace and commas, dropping blank pieces. */
    private static List<String> sqlTokens(String sql) {
        List<String> tokens = new ArrayList<>();
        for (String token : SQL_TOKEN_SEPARATOR.split(sql)) {
            if (StringUtils.isNotBlank(token)) {
                tokens.add(token);
            }
        }
        return tokens;
    }
}
